{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "f833523c",
      "metadata": {
        "id": "f833523c"
      },
      "source": [
        "**This is example of how to trace model with jit and export it to the onnx**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2d832d65",
      "metadata": {
        "id": "2d832d65"
      },
      "outputs": [],
      "source": [
        "!pip install onnx\n",
        "!pip install onnxruntime\n",
        "!pip install git+https://github.com/Denys88/rl_games\n",
        "!pip install envpool\n",
        "!pip install gym\n",
        "!pip install -U colabgymrender"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "yE40EhNFVszf",
      "metadata": {
        "id": "yE40EhNFVszf"
      },
      "outputs": [],
      "source": [
        "from rl_games.torch_runner import Runner\n",
        "import os\n",
        "import yaml\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "import gym\n",
        "from IPython import display\n",
        "import numpy as np\n",
        "import onnx\n",
        "import onnxruntime as ort\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cf09dab6",
      "metadata": {
        "id": "cf09dab6"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi -L"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2enRAdp8WrJV",
      "metadata": {
        "id": "2enRAdp8WrJV"
      },
      "outputs": [],
      "source": [
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "JGE4eeUCWsss",
      "metadata": {
        "id": "JGE4eeUCWsss"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir 'runs/'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "df8682b3",
      "metadata": {
        "id": "df8682b3"
      },
      "outputs": [],
      "source": [
        "config = {'params': {'algo': {'name': 'a2c_discrete'},\n",
        "  'config': {\n",
        "   'clip_value': False,\n",
        "   'critic_coef': 2,\n",
        "   'e_clip': 0.2,\n",
        "   'entropy_coef': 0.01,\n",
        "   'env_config': {'env_name': 'CartPole-v1', 'seed': 5},\n",
        "   'env_name': 'envpool',\n",
        "   'full_experiment_name' : 'cartpole_onnx',\n",
        "   'save_best_after' : 20,\n",
        "   'gamma': 0.99,\n",
        "   'grad_norm': 1.0,\n",
        "   'horizon_length': 32,\n",
        "   'kl_threshold': 0.008,\n",
        "   'learning_rate': '3e-4',\n",
        "   'lr_schedule': 'adaptive',\n",
        "   'max_epochs': 100,\n",
        "   'mini_epochs': 5,\n",
        "   'minibatch_size': 1024,\n",
        "   'name': 'cartpole',\n",
        "   'normalize_advantage': True,\n",
        "   'normalize_input': True,\n",
        "   'normalize_value': True,\n",
        "   'num_actors': 64,\n",
        "   'player': {'render': True},\n",
        "   'ppo': True,\n",
        "   'reward_shaper': {'scale_value': 0.1},\n",
        "   'schedule_type': 'standard',\n",
        "   'score_to_win': 20000,\n",
        "   'tau': 0.95,\n",
        "   'truncate_grads': True,\n",
        "   'use_smooth_clamp': False,\n",
        "   'value_bootstrap': True},\n",
        "  'model': {'name': 'discrete_a2c'},\n",
        "  'network': {'mlp': {'activation': 'elu',\n",
        "    'initializer': {'name': 'default'},\n",
        "    'units': [32, 32]},\n",
        "   'name': 'actor_critic',\n",
        "   'separate': False,\n",
        "   'space': {'discrete': {},\n",
        "  'seed': 5}}\n",
        "  }\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c91c090f",
      "metadata": {
        "id": "c91c090f"
      },
      "outputs": [],
      "source": [
        "runner = Runner()\n",
        "runner.load(config)\n",
        "runner.run({\n",
        "    'train': True,\n",
        "})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bc130c78",
      "metadata": {
        "id": "bc130c78"
      },
      "outputs": [],
      "source": [
        "class ModelWrapper(torch.nn.Module):\n",
        "    '''\n",
        "    Main idea is to ignore outputs which we don't need from model\n",
        "    '''\n",
        "    def __init__(self, model):\n",
        "        torch.nn.Module.__init__(self)\n",
        "        self._model = model\n",
        "        \n",
        "        \n",
        "    def forward(self,input_dict):\n",
        "        input_dict['obs'] = self._model.norm_obs(input_dict['obs'])\n",
        "        '''\n",
        "        just model export doesn't work. Looks like onnx issue with torch distributions\n",
        "        thats why we are exporting only neural network\n",
        "        '''\n",
        "        #print(input_dict)\n",
        "        #output_dict = self._model.a2c_network(input_dict)\n",
        "        #input_dict['is_train'] = False\n",
        "        #return output_dict['logits'], output_dict['values']\n",
        "        return self._model.a2c_network(input_dict)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "40268292",
      "metadata": {
        "id": "40268292"
      },
      "outputs": [],
      "source": [
        "agent = runner.create_player()\n",
        "agent.restore('runs/cartpole_onnx/nn/cartpole.pth')\n",
        "\n",
        "import rl_games.algos_torch.flatten as flatten\n",
        "inputs = {\n",
        "    'obs' : torch.zeros((1,) + agent.obs_shape).to(agent.device),\n",
        "    'rnn_states' : agent.states,\n",
        "}\n",
        "\n",
        "with torch.no_grad():\n",
        "    adapter = flatten.TracingAdapter(ModelWrapper(agent.model), inputs,allow_non_tensor=True)\n",
        "    traced = torch.jit.trace(adapter, adapter.flattened_inputs,check_trace=False)\n",
        "    flattened_outputs = traced(*adapter.flattened_inputs)\n",
        "    print(flattened_outputs)\n",
        "    \n",
        "torch.onnx.export(traced, *adapter.flattened_inputs, \"cartpole.onnx\", verbose=True, input_names=['obs'], output_names=['logits', 'value'])\n",
        "\n",
        "onnx_model = onnx.load(\"cartpole.onnx\")\n",
        "\n",
        "# Check that the model is well formed\n",
        "onnx.checker.check_model(onnx_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "09c2e424",
      "metadata": {
        "id": "09c2e424"
      },
      "outputs": [],
      "source": [
        "ort_model = ort.InferenceSession(\"cartpole.onnx\")\n",
        "\n",
        "outputs = ort_model.run(\n",
        "    None,\n",
        "    {\"obs\": np.zeros((1, 4)).astype(np.float32)},\n",
        ")\n",
        "print(outputs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "04a41060",
      "metadata": {
        "id": "04a41060"
      },
      "outputs": [],
      "source": [
        "os.environ[\"SDL_VIDEODRIVER\"] = \"dummy\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a32c50a1",
      "metadata": {
        "id": "a32c50a1"
      },
      "outputs": [],
      "source": [
        "is_done = False\n",
        "\n",
        "# using regular openai gym to render\n",
        "env = gym.make('CartPole-v1')\n",
        "obs = env.reset()\n",
        "prev_screen = env.render(mode='rgb_array')\n",
        "plt.imshow(prev_screen)\n",
        "total_reward = 0\n",
        "num_steps = 0\n",
        "\n",
        "while not is_done:\n",
        "    outputs = ort_model.run(None, {\"obs\": np.expand_dims(obs, axis=0).astype(np.float32)},)\n",
        "\n",
        "    action = np.argmax(outputs[0])\n",
        "    obs, reward, done, info = env.step(action)\n",
        "    total_reward += reward\n",
        "    num_steps += 1\n",
        "    is_done = done\n",
        "\n",
        "    screen = env.render(mode='rgb_array')\n",
        "    plt.imshow(screen)\n",
        "    display.display(plt.gcf())    \n",
        "    display.clear_output(wait=True)\n",
        "\n",
        "print(total_reward, num_steps)\n",
        "display.clear_output(wait=True)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "warp39",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "20dffcfa027a5ca97c32e660f6348a5dd89a4a8771672beb12fd55712d57511e"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
